#pragma once
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>

#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <unistd.h>

// Casts a pointer to a concrete socket-address structure (e.g. sockaddr_in *)
// to the generic `struct sockaddr *` expected by bind/accept/connect.
// The argument is parenthesized so the cast applies to the whole expression:
// without it, CONVERT(p + 1) would cast `p` first and then step by
// sizeof(struct sockaddr) instead of by the size of the pointed-to type.
#define CONVERT(addr) ((struct sockaddr *)(addr))

namespace Net_Work
{
    // Backlog passed to listen() when the caller does not supply one.
    const int default_backlog = 5;
    // Descriptor value stored by a socket object that has no descriptor attached.
    const int default_sockfd = -1;

    // Process exit codes used by the *OrDie operations when the underlying
    // system call fails.
    enum
    {
        SocketError = 1,
        BindError,
        ListenError
    };

    // Design pattern: template-method class.
    // The pure virtual functions are the primitive socket operations; the
    // non-virtual Build* functions below fix the order in which those
    // primitives are invoked for each use case.
    class Socket
    {
    public:
        // Virtual so that objects handed out as Socket * (see AcceptSocket)
        // can be deleted safely through the base pointer.
        virtual ~Socket() = default;

        // Creates the underlying socket; terminates the process on failure.
        virtual void CreateSocketOrDie() = 0;
        // Binds the socket to `port`; terminates the process on failure.
        virtual void BindSocketOrDie(uint16_t port) = 0;
        // Puts the socket into listening state; terminates the process on failure.
        virtual void ListenSocketOrDie(int backlog = default_backlog) = 0;
        // Accepts one connection. Returns a new heap-allocated Socket owned by
        // the caller, or nullptr on failure; the peer address is written to
        // the two out-parameters.
        virtual Socket *AcceptSocket(std::string *clinet_ip, uint16_t *cilent_port) = 0;
        // Connects to serverip:serverport; returns true on success.
        virtual bool ConnectSocket(const std::string &serverip, uint16_t serverport) = 0;
        // Replaces the stored descriptor with `sockfd`.
        virtual void SetSockfd(int sockfd) = 0;
        // Returns the stored descriptor.
        virtual int GetSockfd() = 0;
        // Closes the stored descriptor.
        virtual void ColseSockfd() = 0;
        // Sends the bytes of `send_str` to the peer.
        virtual void Send(std::string &send_str) = 0;
        // Receives data and appends it to *buffer; returns true if any bytes
        // were received.
        virtual bool Recv(std::string *buffer, int size) = 0;

    public:
        // Builds a listening socket: create -> bind -> listen.
        void BUlidListenSockfdMethod(uint16_t port, int backlog = default_backlog)
        {
            CreateSocketOrDie();
            BindSocketOrDie(port);
            ListenSocketOrDie(backlog);
        }

        // Builds a client socket: create -> connect. Returns the result of
        // ConnectSocket.
        bool BuildConnectSocketMethod(const std::string &serverip, uint16_t serverport)
        {
            CreateSocketOrDie();
            return ConnectSocket(serverip, serverport);
        }

        // Wraps an already-open descriptor.
        void BulidNormalMethod(int sockfd)
        {
            SetSockfd(sockfd);
        }
    };

    // IPv4 TCP implementation of Socket on top of the POSIX socket API.
    class TcpSokcet : public Socket
    {
    public:
        // Wraps `sockfd`; with the default argument the object starts with no
        // descriptor attached.
        TcpSokcet(int sockfd = default_sockfd) : _sockfd(sockfd)
        {
        }

        // Does NOT close the descriptor: closing is explicit via ColseSockfd().
        ~TcpSokcet() override
        {
        }

        // Creates an AF_INET/SOCK_STREAM socket; exits with SocketError on failure.
        void CreateSocketOrDie() override
        {
            _sockfd = ::socket(AF_INET, SOCK_STREAM, 0);
            if (_sockfd < 0)
            {
                std::exit(SocketError);
            }
        }

        // Binds to `port` on all local interfaces; exits with BindError on failure.
        void BindSocketOrDie(uint16_t port) override
        {
            sockaddr_in local{}; // value-initialized: all fields zero
            local.sin_family = AF_INET;
            local.sin_port = htons(port);
            local.sin_addr.s_addr = htonl(INADDR_ANY);

            if (::bind(_sockfd, reinterpret_cast<const sockaddr *>(&local), sizeof(local)) < 0)
            {
                std::exit(BindError);
            }
        }

        // Starts listening with the given backlog; exits with ListenError on failure.
        void ListenSocketOrDie(int backlog = default_backlog) override
        {
            if (::listen(_sockfd, backlog) < 0)
            {
                std::exit(ListenError);
            }
        }

        // Accepts one pending connection.
        // On success returns a TcpSokcet allocated with `new` (the caller is
        // responsible for deleting it) and stores the peer's dotted-quad
        // address and host-order port in the out-parameters. Either
        // out-parameter may be nullptr if the caller does not need it.
        // Returns nullptr if accept() fails.
        Socket *AcceptSocket(std::string *clinet_ip, uint16_t *cilent_port) override
        {
            sockaddr_in peer{};
            socklen_t len = sizeof(peer);
            const int newsockfd = ::accept(_sockfd, reinterpret_cast<sockaddr *>(&peer), &len);
            if (newsockfd < 0)
                return nullptr;

            if (clinet_ip != nullptr)
            {
                // inet_ntop writes into a caller-supplied buffer, unlike
                // inet_ntoa which returns a pointer to shared static storage.
                char text[INET_ADDRSTRLEN] = {0};
                if (::inet_ntop(AF_INET, &peer.sin_addr, text, sizeof(text)) != nullptr)
                    *clinet_ip = text;
                else
                    clinet_ip->clear();
            }
            if (cilent_port != nullptr)
                *cilent_port = ntohs(peer.sin_port);

            return new TcpSokcet(newsockfd);
        }

        // Connects the stored socket to serverip:serverport.
        // Returns false if `serverip` is not a valid IPv4 address string or if
        // connect() fails; true otherwise.
        bool ConnectSocket(const std::string &serverip, uint16_t serverport) override
        {
            sockaddr_in server{};
            server.sin_family = AF_INET;
            server.sin_port = htons(serverport);
            // inet_aton accepts the same notations as inet_addr but reports a
            // malformed string explicitly instead of returning a value that is
            // indistinguishable from 255.255.255.255.
            if (::inet_aton(serverip.c_str(), &server.sin_addr) == 0)
                return false;

            return ::connect(_sockfd, reinterpret_cast<const sockaddr *>(&server), sizeof(server)) >= 0;
        }

        // Replaces the stored descriptor (the previous one is not closed).
        void SetSockfd(int sockfd) override
        {
            _sockfd = sockfd;
        }

        // Returns the stored descriptor.
        int GetSockfd() override
        {
            return _sockfd;
        }

        // Closes the stored descriptor if one is attached. The stored value is
        // left unchanged afterwards.
        void ColseSockfd() override
        {
            if (_sockfd > default_sockfd)
                ::close(_sockfd);
        }

        // Sends every byte of `send_str`, calling send() again after a short
        // write. Stops silently on the first error, since the interface has no
        // way to report it.
        void Send(std::string &send_str) override
        {
            const char *cursor = send_str.data();
            std::string::size_type remaining = send_str.size();
            while (remaining > 0)
            {
                const ssize_t written = ::send(_sockfd, cursor, remaining, 0);
                if (written <= 0)
                    return; // error (or no progress): give up on the rest
                cursor += written;
                remaining -= static_cast<std::string::size_type>(written);
            }
        }

        // Receives at most `size - 1` bytes with a single recv() call and
        // appends exactly the bytes received to *buffer (embedded NUL bytes
        // are preserved).
        // Returns true if at least one byte was received; false on error, on
        // end of stream, when `buffer` is nullptr, or when `size` leaves no
        // room for data (size <= 1).
        bool Recv(std::string *buffer, int size) override
        {
            if (buffer == nullptr || size <= 1)
                return false;

            // Heap-backed scratch space instead of a variable-length array,
            // which is not standard C++.
            std::string chunk(static_cast<std::string::size_type>(size - 1), '\0');
            const ssize_t received = ::recv(_sockfd, &chunk[0], chunk.size(), 0);
            if (received <= 0)
                return false;

            buffer->append(chunk.data(), static_cast<std::string::size_type>(received));
            return true;
        }

    private:
        int _sockfd; // wrapped descriptor, or default_sockfd when none is attached
    };
}